1
超越逐元素操作:迈向分块矩阵运算
AI023Lesson 9
00:00

在之前的课程中,我们重点学习了 逐元素操作 (例如矩阵上的基础 ReLU 操作)。这些操作属于 内存受限 因为 GPU 花费在将数据从高带宽内存(HBM)移动到寄存器上的时间,远多于执行数学计算的时间。

1. 为什么 GEMM 至关重要

通用矩阵乘法(GEMM)的计算复杂度为 $O(N^3)$,但仅需 $O(N^2)$ 的内存访问。这使我们能够利用巨大的算术吞吐量来隐藏内存延迟,因此它成为大语言模型的“核心心跳”。

2. 二维内存表示

物理内存是一维的。为了表示二维张量,我们使用 步幅(Strides)。一个常见的生产环境陷阱是 假设张量是连续的。如果你在指针计算中混淆了行与列的步幅,就会访问到‘幽灵’数据或引发内存违规。

3. 分块泛化

Triton 通过从 单个指针 转变为 指针块。通过使用二维分块(例如 $16 \times 16$),我们能充分利用高速缓存中的 数据复用 ,使数据在高速缓存中保持‘热态’,以便在写回全局内存前进行融合操作,如偏置加法或激活函数计算。

1D 线性布局2D 分块布局
main.py
TERMINALbash — 80x24
> Ready. Click "Run" to execute.
>